import argparse
import os
import random

def parse_args(argv=None):
    """Parse command-line options for the dataset split.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Split label files into train/val/test name lists."
    )
    parser.add_argument(
        "--xml_path",
        default="/run/media/kearney/a/CAU/42course/毕设/datasets/NEU-DET/labels",
        type=str,
        help="input xml label path",
    )
    parser.add_argument(
        "--txt_path",
        default="/run/media/kearney/a/CAU/42course/毕设/datasets/NEU-DET",
        type=str,
        help="output txt label path",
    )
    # Share of all samples that go to train+val; the rest is test
    # (0.8 with the default train share gives a 6:2:2 split).
    parser.add_argument(
        "--trainval_percent",
        default=0.8,
        type=float,
        help="fraction of all samples used for train+val (rest is test)",
    )
    # Share of the train+val subset that goes to train: 6 / (6 + 2).
    parser.add_argument(
        "--train_percent",
        default=0.75,
        type=float,
        help="fraction of train+val samples used for train (rest is val)",
    )
    parser.add_argument(
        "--seed",
        default=None,
        type=int,
        help="random seed for a reproducible split (default: unseeded)",
    )
    opt = parser.parse_args(argv)
    for flag in ("trainval_percent", "train_percent"):
        value = getattr(opt, flag)
        if not 0.0 <= value <= 1.0:
            parser.error("--%s must be in [0, 1], got %r" % (flag, value))
    return opt


def list_label_stems(label_dir):
    """Return the extension-less names of the label files in *label_dir*.

    Only regular, non-hidden files are considered (so entries such as
    ``.DS_Store`` or sub-directories are skipped). Names are sorted so the
    result does not depend on the filesystem's ``listdir`` order.
    """
    stems = []
    for entry in sorted(os.listdir(label_dir)):
        if entry.startswith("."):
            continue
        if not os.path.isfile(os.path.join(label_dir, entry)):
            continue
        # splitext handles extensions of any length, unlike a fixed [:-4].
        stems.append(os.path.splitext(entry)[0])
    return stems


def split_names(names, trainval_percent=0.8, train_percent=0.75, seed=None):
    """Randomly partition *names* into train, val and test lists.

    ``int(len(names) * trainval_percent)`` items go to train+val, and
    ``int(that * train_percent)`` of those go to train. Every item lands in
    exactly one list, and each list keeps the relative order of *names*.

    Args:
        names: Sequence of sample names.
        trainval_percent: Fraction in [0, 1] assigned to train+val.
        train_percent: Fraction in [0, 1] of train+val assigned to train.
        seed: Seed for the private RNG; ``None`` means non-deterministic.

    Returns:
        A ``(train, val, test)`` tuple of lists.

    Raises:
        ValueError: If either fraction is outside [0, 1].
    """
    if not 0.0 <= trainval_percent <= 1.0:
        raise ValueError("trainval_percent must be in [0, 1]")
    if not 0.0 <= train_percent <= 1.0:
        raise ValueError("train_percent must be in [0, 1]")

    rng = random.Random(seed)
    num = len(names)
    n_trainval = int(num * trainval_percent)
    n_train = int(n_trainval * train_percent)

    trainval_idx = rng.sample(range(num), n_trainval)
    # Sets give O(1) membership tests in the loop below.
    train = set(rng.sample(trainval_idx, n_train))
    trainval = set(trainval_idx)

    train_names, val_names, test_names = [], [], []
    for i, name in enumerate(names):
        if i not in trainval:
            test_names.append(name)
        elif i in train:
            train_names.append(name)
        else:
            val_names.append(name)
    return train_names, val_names, test_names


def write_name_list(path, names):
    """Write *names* to *path*, one per line, as UTF-8 text."""
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(name + "\n" for name in names)


def main(argv=None):
    """Split the label directory and write train.txt, val.txt and test.txt."""
    opt = parse_args(argv)
    names = list_label_stems(opt.xml_path)
    os.makedirs(opt.txt_path, exist_ok=True)

    train, val, test = split_names(
        names, opt.trainval_percent, opt.train_percent, opt.seed
    )
    write_name_list(os.path.join(opt.txt_path, "test.txt"), test)
    write_name_list(os.path.join(opt.txt_path, "train.txt"), train)
    write_name_list(os.path.join(opt.txt_path, "val.txt"), val)


if __name__ == "__main__":
    main()
